{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp problem_types.masklm\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test setup\n",
    "import tensorflow as tf\n",
    "from m3tl.test_base import TestBase\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Masked Language Model(masklm)\n",
    "\n",
    "This module includes neccessary part to register Masked Language Model problem type."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "import pickle\n",
    "from typing import List\n",
    "\n",
    "import tensorflow as tf\n",
    "from loguru import logger\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import (empty_tensor_handling_loss,\n",
    "                                      nan_loss_handling, pad_to_shape)\n",
    "from m3tl.special_tokens import PREDICT\n",
    "from m3tl.utils import (LabelEncoder, gather_indexes,\n",
    "                        get_label_encoder_save_path, get_phase,\n",
    "                        load_transformer_tokenizer, need_make_label_encoder)\n",
    "from transformers import TFSharedEmbeddings\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class MaskLM(tf.keras.Model):\n",
    "    \"\"\"Multimodal MLM top layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params: BaseParams, problem_name: str, input_embeddings: tf.keras.layers.Layer=None, share_embedding=True) -> None:\n",
    "        super(MaskLM, self).__init__(name=problem_name)\n",
    "        self.params = params\n",
    "        self.problem_name = problem_name\n",
    "\n",
    "        if share_embedding is False:\n",
    "            self.vocab_size = self.params.bert_config.vocab_size\n",
    "            self.share_embedding = False\n",
    "        else:\n",
    "            self.vocab_size = input_embeddings.shape[0]\n",
    "            embedding_size = input_embeddings.shape[-1]\n",
    "            share_valid = (self.params.bert_config.hidden_size ==\n",
    "                        embedding_size)\n",
    "            if not share_valid and self.params.share_embedding:\n",
    "                logger.warning(\n",
    "                    'Share embedding is enabled but hidden_size != embedding_size')\n",
    "            self.share_embedding = self.params.share_embedding & share_valid\n",
    "\n",
    "        if self.share_embedding:\n",
    "            self.share_embedding_layer = TFSharedEmbeddings(\n",
    "                vocab_size=self.vocab_size, hidden_size=input_embeddings.shape[1])\n",
    "            self.share_embedding_layer.build([1])\n",
    "            self.share_embedding_layer.weight = input_embeddings\n",
    "        else:\n",
    "            self.share_embedding_layer = tf.keras.layers.Dense(self.vocab_size)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        mode = get_phase()\n",
    "        features, hidden_features = inputs\n",
    "\n",
    "        # masking is done inside the model\n",
    "        seq_hidden_feature = hidden_features['seq']\n",
    "        if mode != PREDICT:\n",
    "            positions = features['masked_lm_positions']\n",
    "\n",
    "            # gather_indexes will flatten the seq hidden_states, we need to reshape\n",
    "            # back to 3d tensor\n",
    "            input_tensor = gather_indexes(seq_hidden_feature, positions)\n",
    "            shape_tensor = tf.shape(positions)\n",
    "            shape_list = tf.concat([shape_tensor, [seq_hidden_feature.shape.as_list()[-1]]], axis=0)\n",
    "            input_tensor = tf.reshape(input_tensor, shape=shape_list)\n",
    "            # set_shape to determin rank\n",
    "            input_tensor.set_shape(\n",
    "                [None, None, seq_hidden_feature.shape.as_list()[-1]])\n",
    "        else:\n",
    "            input_tensor = seq_hidden_feature\n",
    "        if self.share_embedding:\n",
    "            mlm_logits = self.share_embedding_layer(\n",
    "                input_tensor, mode='linear')\n",
    "        else:\n",
    "            mlm_logits = self.share_embedding_layer(input_tensor)\n",
    "        if mode != PREDICT:\n",
    "            mlm_labels = features['masked_lm_ids']\n",
    "            mlm_labels.set_shape([None, None])\n",
    "            mlm_labels = pad_to_shape(from_tensor=mlm_labels, to_tensor=mlm_logits, axis=1)\n",
    "            # compute loss\n",
    "            mlm_loss = empty_tensor_handling_loss(\n",
    "                mlm_labels,\n",
    "                mlm_logits,\n",
    "                tf.keras.losses.sparse_categorical_crossentropy\n",
    "            )\n",
    "            loss = nan_loss_handling(mlm_loss)\n",
    "            self.add_loss(loss)\n",
    "\n",
    "        return tf.nn.softmax(mlm_logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get or make label encoder function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def masklm_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:\n",
    "\n",
    "    le_path = get_label_encoder_save_path(params=params, problem=problem)\n",
    "\n",
    "    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):\n",
    "        # fit and save label encoder\n",
    "        label_encoder = load_transformer_tokenizer(params.transformer_tokenizer_name, params.transformer_tokenizer_loading)\n",
    "        pickle.dump(label_encoder, open(le_path, 'wb'))\n",
    "        try:\n",
    "            params.set_problem_info(problem=problem, info_name='num_classes', info=len(label_encoder.vocab))\n",
    "        except AttributeError:\n",
    "            # models like xlnet's vocab size can only be retrieved from config instead of tokenizer\n",
    "            params.set_problem_info(problem=problem, info_name='num_classes', info=params.bert_config.vocab_size)\n",
    "    else:\n",
    "        # models like xlnet's vocab size can only be retrieved from config instead of tokenizer\n",
    "        params.set_problem_info(problem=problem, info_name='num_classes', info=params.bert_config.vocab_size)\n",
    "        label_encoder = pickle.load(open(le_path, 'rb'))\n",
    "\n",
    "    return label_encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label handing function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def masklm_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):\n",
    "    # masklm is a special case since it modifies inputs\n",
    "    # for more standard implementation of masklm, please see premask_mlm\n",
    "    return None, None\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
